{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Manipulating Data Distribution\n",
    "\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HanXudong/fairlib/blob/main/tutorial/manipulate_data_distribution.ipynb)\n",
    "\n",
    "In this notebook, we introduce the built-in functions for manipulating data distributions.\n",
    "\n",
    "Overall, the process can be described as:\n",
    "1. Load data from files\n",
    "2. Analyze the distribution of the loaded dataset\n",
    "3. Resample instances based on their target labels and protected labels\n",
    "\n",
    "Data loading (step 1) is implemented in the same way as for the other dataloaders. Please see [dataloader](https://hanxudong.github.io/fairlib/reference_api_dataloaders.html) for a detailed description."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load data\n",
    "\n",
    "Similar to long-tail learning, the data distribution is essentially the label distribution. However, besides the target label in long-tail learning, we also need to consider the demographic labels.\n",
    "\n",
    "Since the input distribution is not required for manipulating data distributions, we just create synthetic labels for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "y = []\n",
    "g = []\n",
    "for _y, _g, _n in zip(\n",
    "    [1,1,1,0,0,0], # Target labels, y\n",
    "    [2,1,0,1,2,0], # Protected labels, g\n",
    "    [4,5,6,3,7,9]  # Number of instances with target label _y and group label _g\n",
    "    ):\n",
    "    y = y + [_y]*_n\n",
    "    g = g + [_g]*_n\n",
    "\n",
    "y = np.array(y)\n",
    "g = np.array(g)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Analysis of the loaded dataset distribution\n",
    "\n",
    "Essentially, we are interested in the joint distribution of the target label *y* and the protected label *g*, both of which are discrete random variables.\n",
    "\n",
    "Given the target labels and protected labels, `get_data_distribution` calculates the empirical distributions and stores them as probability tables:\n",
    "\n",
    "- `joint_dist`: n_class * n_groups matrix, where each element is the joint probability, i.e., the proportion of instances with that (y, g) pair.\n",
    "- `g_dist`: n_groups array, the probability of each group\n",
    "- `y_dist`: n_class array, the probability of each class\n",
    "- `g_cond_y_dist`: n_class * n_groups matrix, where `g_cond_y_dist[y_id,:]` is the group distribution within class y_id\n",
    "- `y_cond_g_dist`: n_class * n_groups matrix, where `y_cond_g_dist[:,g_id]` is the class distribution within group g_id"
   ]
  },
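  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what these tables contain, the sketch below recomputes them with plain NumPy from the synthetic counts defined earlier. It is not fairlib's implementation of `get_data_distribution`; the variable names simply mirror the keys listed above, and the `counts` dictionary is a helper introduced here for illustration.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Hypothetical helper: (y, g) -> number of instances, same numbers as the synthetic data.\n",
    "counts = {(1, 2): 4, (1, 1): 5, (1, 0): 6, (0, 1): 3, (0, 2): 7, (0, 0): 9}\n",
    "y = np.concatenate([np.full(n, _y) for (_y, _g), n in counts.items()])\n",
    "g = np.concatenate([np.full(n, _g) for (_y, _g), n in counts.items()])\n",
    "\n",
    "n_class, n_groups = y.max() + 1, g.max() + 1\n",
    "joint_counts = np.zeros((n_class, n_groups))\n",
    "np.add.at(joint_counts, (y, g), 1)  # count instances per (y, g) pair\n",
    "\n",
    "joint_dist = joint_counts / joint_counts.sum()  # p(y, g)\n",
    "y_dist = joint_dist.sum(axis=1)                 # p(y)\n",
    "g_dist = joint_dist.sum(axis=0)                 # p(g)\n",
    "g_cond_y_dist = joint_dist / y_dist[:, None]    # p(g | y), each row sums to 1\n",
    "y_cond_g_dist = joint_dist / g_dist[None, :]    # p(y | g), each column sums to 1\n",
    "```\n",
    "\n",
    "Each row of `g_cond_y_dist` and each column of `y_cond_g_dist` is a proper distribution, which matches the indexing convention described above."
   ]
  },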
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairlib.src.dataloaders.generalized_BT import get_data_distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['joint_dist', 'g_dist', 'y_dist', 'g_cond_y_dist', 'y_cond_g_dist', 'yg_index', 'N'])\n"
     ]
    }
   ],
   "source": [
    "synthetic_data_distribution = get_data_distribution(y_data = y, g_data = g)\n",
    "print(synthetic_data_distribution.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0.26470588, 0.08823529, 0.20588235],\n",
       "       [0.17647059, 0.14705882, 0.11764706]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# n_class * n_groups matrix, where each element refers to the joint probability\n",
    "synthetic_data_distribution[\"joint_dist\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Resample instances based on their target labels and protected labels\n",
    "\n",
    "To specify a target distribution of *y* and *g*, the `generalized_sampling` function offers 5 options, which fall into 3 types:\n",
    "1. specify the target joint distribution of *y* and *g*, i.e., *p(y,g)*.\n",
    "2. specify y_dist (*p(y)*) and g_cond_y_dist (*p(g|y)*), so that *p(y,g) = p(g|y)p(y)*\n",
    "3. specify g_dist (*p(g)*) and y_cond_g_dist (*p(y|g)*), so that *p(y,g) = p(y|g)p(g)*"
   ]
  },
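  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the third option concrete, the sketch below builds a target joint distribution from a uniform *p(g)* and the original *p(y|g)*, using only NumPy. It assumes that a conditional which is not specified is taken from the original data; how `generalized_sampling` turns the resulting probabilities into a list of indices is not shown here.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Original joint distribution p(y, g) of the synthetic data (rows: y, columns: g).\n",
    "joint_dist = np.array([[9, 3, 7], [6, 5, 4]]) / 34\n",
    "\n",
    "g_dist = joint_dist.sum(axis=0)               # original p(g)\n",
    "y_cond_g_dist = joint_dist / g_dist[None, :]  # original p(y | g)\n",
    "\n",
    "target_g_dist = np.ones(3) / 3                # uniform p(g)\n",
    "# p(y, g) = p(y | g) p(g), applied column by column\n",
    "target_joint = y_cond_g_dist * target_g_dist[None, :]\n",
    "\n",
    "print(target_joint.sum())        # total probability mass\n",
    "print(target_joint.sum(axis=0))  # marginal over y recovers the target p(g)\n",
    "```\n",
    "\n",
    "Only the group marginal changes; the class distribution within each group stays as it was in the original data."
   ]
  },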
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairlib.src.dataloaders.generalized_BT import generalized_sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Taking the group label distribution as an example, we would like the group labels to be uniformly distributed, i.e., *p(g=0)=p(g=1)=p(g=2)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.33333333, 0.33333333, 0.33333333])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_g_dist = np.ones(3)/3\n",
    "target_g_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[25, 26, 27, 28, 15, 16, 17, 18, 19, 20, 21, 9, 10, 11, 4, 5, 6, 7, 0, 1]\n"
     ]
    }
   ],
   "source": [
    "balanced_g_indices = generalized_sampling(\n",
    "    default_distribution_dict = synthetic_data_distribution,\n",
    "    N = 20,\n",
    "    g_dist=target_g_dist,)\n",
    "print(balanced_g_indices)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "balanced_g_distribution = get_data_distribution(\n",
    "    y_data = y[balanced_g_indices], \n",
    "    g_data = g[balanced_g_indices])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'joint_dist': array([[0.2 , 0.15, 0.2 ],\n",
       "        [0.15, 0.2 , 0.1 ]]),\n",
       " 'g_dist': array([0.35, 0.35, 0.3 ]),\n",
       " 'y_dist': array([0.55, 0.45]),\n",
       " 'g_cond_y_dist': array([[0.36363636, 0.27272727, 0.36363636],\n",
       "        [0.33333333, 0.44444444, 0.22222222]]),\n",
       " 'y_cond_g_dist': array([[0.57142857, 0.42857143, 0.66666667],\n",
       "        [0.42857143, 0.57142857, 0.33333333]]),\n",
       " 'yg_index': {(0, 0): array([0, 1, 2, 3], dtype=int64),\n",
       "  (0, 1): array([4, 5, 6], dtype=int64),\n",
       "  (0, 2): array([ 7,  8,  9, 10], dtype=int64),\n",
       "  (1, 0): array([11, 12, 13], dtype=int64),\n",
       "  (1, 1): array([14, 15, 16, 17], dtype=int64),\n",
       "  (1, 2): array([18, 19], dtype=int64)},\n",
       " 'N': 20.0}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "balanced_g_distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In principle, the `generalized_sampling` function can sample any data distribution, but it is inconvenient to specify the distribution array every time. We therefore provide the `manipulate_data_distribution` function to simplify the manipulation process.\n",
    "\n",
    "Specifically, the manipulation interpolates between the original data distribution and a perfectly balanced target distribution. These balanced target distributions are identical to those of [**BT**](https://hanxudong.github.io/fairlib/reference_api_debiasing/BT.html).\n",
    "\n",
    "*alpha* controls the extent of the interpolation: $final\\_dist = \\alpha \\cdot balanced\\_dist + (1-\\alpha) \\cdot original\\_dist$\n",
    "\n",
    "The arguments of `manipulate_data_distribution` are:\n",
    "\n",
    "- default_distribution_dict (dict): a dict of distribution information of the original dataset.\n",
    "- N (int, optional): The total number of returned indices. Defaults to None.\n",
    "- GBTObj (str, optional): original | joint | g | y | g_cond_y | y_cond_g. Defaults to \"original\".\n",
    "- alpha (float, optional): interpolation between the original distribution and the target distribution. Defaults to 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from fairlib.src.dataloaders.generalized_BT import manipulate_data_distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following examples show how far the distribution can be manipulated by tuning the alpha value:\n",
    "- alpha > 1: anti-imbalance.\n",
    "- alpha = 1: perfectly balanced distribution, as in the corresponding BT objective.\n",
    "- 0 < alpha < 1: interpolation between the original distribution and the perfectly balanced distribution.\n",
    "- alpha = 0: original distribution.\n",
    "- alpha < 0: amplified imbalance."
   ]
  },
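  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The regimes above follow directly from the interpolation formula. As a small numeric check, the sketch below applies the formula to the group distribution of the synthetic data with plain NumPy. It only evaluates the formula and does not call fairlib, so the proportions sampled by `manipulate_data_distribution` may differ slightly because N is finite.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "original_g = np.array([15, 8, 11]) / 34  # p(g) of the synthetic data\n",
    "balanced_g = np.ones(3) / 3              # perfectly balanced p(g)\n",
    "\n",
    "final_g = {}\n",
    "for alpha in (1.0, 0.5, 0.0, -1.0):\n",
    "    # final_dist = alpha * balanced_dist + (1 - alpha) * original_dist\n",
    "    final_g[alpha] = alpha * balanced_g + (1 - alpha) * original_g\n",
    "    print(alpha, final_g[alpha].round(3))\n",
    "```\n",
    "\n",
    "With alpha = -1 the largest group gains probability mass and the smallest group loses it, which is the amplified imbalance. Each interpolated vector still sums to 1 because both endpoints do."
   ]
  },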
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_indices_10 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =1)\n",
    "g_indices_00 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =0)\n",
    "g_indices_05 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =0.5)\n",
    "g_indices_n10 = manipulate_data_distribution(default_distribution_dict = synthetic_data_distribution, N = 20, GBTObj = \"g\", alpha =-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original: [0.44117647 0.23529412 0.32352941]\n",
      "Balanced g: [0.35 0.35 0.3 ]\n",
      "alpha 0.0: [0.45 0.25 0.3 ]\n",
      "alpha 0.5: [0.4 0.3 0.3]\n",
      "alpha 1.0: [0.35 0.35 0.3 ]\n",
      "alpha -1.0: [0.55 0.15 0.3 ]\n"
     ]
    }
   ],
   "source": [
    "print(\"Original:\", synthetic_data_distribution[\"g_dist\"])\n",
    "print(\"Balanced g:\", get_data_distribution(y_data = y[balanced_g_indices], g_data = g[balanced_g_indices])[\"g_dist\"])\n",
    "print(\"alpha 0.0:\", get_data_distribution(y_data = y[g_indices_00], g_data = g[g_indices_00])[\"g_dist\"])\n",
    "print(\"alpha 0.5:\", get_data_distribution(y_data = y[g_indices_05], g_data = g[g_indices_05])[\"g_dist\"])\n",
    "print(\"alpha 1.0:\", get_data_distribution(y_data = y[g_indices_10], g_data = g[g_indices_10])[\"g_dist\"])\n",
    "print(\"alpha -1.0:\", get_data_distribution(y_data = y[g_indices_n10], g_data = g[g_indices_n10])[\"g_dist\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Limitation and Extension\n",
    "\n",
    "The `manipulate_data_distribution` function relies on the strong assumption that the original distribution is **NOT** identical to the perfectly balanced target distribution. If the two are identical, the interpolation returns the same distribution for every alpha. For a dataset that is already perfectly balanced, the target distributions need to be specified manually and passed to `generalized_sampling` for resampling."
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "eb4eca40f7710e7c7146430b0424b585ee7a07b7e7498958627feae1c8ad8261"
  },
  "kernelspec": {
   "display_name": "Python 3.7.12 ('py37')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
